# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import saving
from tensorflow.python.keras._impl.keras.engine import sequential
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.keras._impl.keras.engine.input_layer import Input
from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg


# API entries importable from `keras.models`:
# The two model classes are re-exported from the engine modules under their
# public names (class-style names, hence the `invalid-name` lint pragmas).
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
# Saving / loading / deserialization helpers, re-exported from the engine
# `saving` module unchanged.
save_model = saving.save_model
load_model = saving.load_model
model_from_config = saving.model_from_config
model_from_yaml = saving.model_from_yaml
model_from_json = saving.model_from_json


def _clone_functional_model(model, input_tensors=None):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  The reference model's nodes are walked in decreasing depth order; each
  layer is re-created from its config the first time it is encountered and
  then called on the already-cloned tensors that correspond to the node's
  inputs.

  Arguments:
      model: Instance of `Model`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value.
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got ', model)
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead:', model)

  layer_map = {}  # Cache for created layers.
  tensor_map = {}  # Map {reference_tensor: (corresponding_tensor, mask)}
  if input_tensors is None:
    # Create placeholders to build the model on top of.
    input_tensors = []
    for layer in model._input_layers:
      input_tensor = Input(
          batch_shape=layer._batch_input_shape,
          dtype=layer.dtype,
          sparse=layer.sparse,
          name=layer.name)
      input_tensors.append(input_tensor)
      # Cache newly created input layer.
      layer_map[layer] = input_tensor._keras_history[0]
  else:
    # Make sure that all input tensors come from a Keras layer.
    # If tensor comes from an input layer: cache the input layer.
    input_tensors = generic_utils.to_list(input_tensors)
    input_tensors_ = []
    for i, x in enumerate(input_tensors):
      if not K.is_keras_tensor(x):
        # Key the cache on the reference model's own i-th input layer (the
        # same key the placeholder branch above uses), so that the node loop
        # below finds it in `layer_map` and reuses the wrapper layer instead
        # of cloning the input layer a second time. `x` itself is not a
        # Keras tensor here, so it must not be used to look that layer up.
        original_input_layer = model._input_layers[i]
        name = original_input_layer.name
        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
        input_tensors_.append(input_tensor)
        # Cache newly created input layer.
        layer_map[original_input_layer] = input_tensor._keras_history[0]
      else:
        input_tensors_.append(x)
    input_tensors = input_tensors_

  for x, y in zip(model.inputs, input_tensors):
    tensor_map[x] = (y, None)  # tensor, mask

  # Iterate over every node in the reference model, in depth order.
  depth_keys = list(model._nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = model._nodes_by_depth[depth]
    for node in nodes:
      # Recover the corresponding layer.
      layer = node.outbound_layer

      # Get or create layer.
      if layer not in layer_map:
        # Clone layer.
        new_layer = layer.__class__.from_config(layer.get_config())
        layer_map[layer] = new_layer
        layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # Gather inputs to call the new layer.
      reference_input_tensors = node.input_tensors
      reference_output_tensors = node.output_tensors

      # If all previous input tensors are available in tensor_map,
      # then call node.inbound_layer on them.
      computed_data = []  # List of tuples (input, mask).
      for x in reference_input_tensors:
        if x in tensor_map:
          computed_data.append(tensor_map[x])

      if len(computed_data) == len(reference_input_tensors):
        # Call layer.
        # Work on a copy: a `mask` entry may be added below, and that must
        # not leak into the reference model's node arguments.
        kwargs = dict(node.arguments) if node.arguments else {}
        if len(computed_data) == 1:
          computed_tensor, computed_mask = computed_data[0]
          if has_arg(layer.call, 'mask'):
            if 'mask' not in kwargs:
              kwargs['mask'] = computed_mask
          output_tensors = generic_utils.to_list(layer(computed_tensor,
                                                       **kwargs))
          output_masks = generic_utils.to_list(
              layer.compute_mask(computed_tensor, computed_mask))
        else:
          computed_tensors = [x[0] for x in computed_data]
          computed_masks = [x[1] for x in computed_data]
          if has_arg(layer.call, 'mask'):
            if 'mask' not in kwargs:
              kwargs['mask'] = computed_masks
          output_tensors = generic_utils.to_list(layer(computed_tensors,
                                                       **kwargs))
          output_masks = generic_utils.to_list(
              layer.compute_mask(computed_tensors, computed_masks))
        # Update tensor_map.
        for x, y, mask in zip(reference_output_tensors, output_tensors,
                              output_masks):
          tensor_map[x] = (y, mask)

  # Check that we did compute the model outputs,
  # then instantiate a new model from inputs and outputs.
  output_tensors = []
  for x in model.outputs:
    assert x in tensor_map, 'Could not compute output ' + str(x)
    tensor, _ = tensor_map[x]
    output_tensors.append(tensor)
  return Model(input_tensors, output_tensors, name=model.name)


def _clone_sequential_model(model, input_tensors=None):
  """Clone a `Sequential` model instance.

  Every layer of `model` is re-created from its config, so the clone owns
  freshly instantiated weights rather than sharing those of the original
  layers. The new layers are then stacked into a new `Sequential`, optionally
  on top of a caller-supplied input tensor.

  Arguments:
      model: Instance of `Sequential`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.

  Returns:
      An instance of `Sequential` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: if `model` is not a `Sequential` instance, if
          `input_tensors` does not hold exactly one tensor, or if that
          tensor is a Keras tensor produced by a layer other than an
          `InputLayer`.
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got:', model)

  # Rebuild each layer from its own config; this happens before any
  # validation of `input_tensors`.
  cloned_layers = []
  for original in model.layers:
    cloned_layers.append(original.__class__.from_config(original.get_config()))

  # No inputs supplied: the new model is just the stack of cloned layers.
  if input_tensors is None:
    return Sequential(layers=cloned_layers, name=model.name)

  tensors = generic_utils.to_list(input_tensors)
  if len(tensors) != 1:
    raise ValueError('To clone a `Sequential` model, we expect '
                     ' at most one tensor '
                     'as part of `input_tensors`.')
  tensor = tensors[0]

  if not K.is_keras_tensor(tensor):
    # Non-Keras tensor: wrap it with `Input` and use the layer recorded in
    # the wrapper's `_keras_history` as the bottom of the stack.
    wrapped = Input(tensor=tensor,
                    name='input_wrapper_for_' + str(tensor.name))
    bottom_layer = wrapped._keras_history[0]
  else:
    # Keras tensor: only accepted when it was produced by an `InputLayer`.
    bottom_layer = tensor._keras_history[0]
    if not isinstance(bottom_layer, InputLayer):
      raise ValueError('Cannot clone a `Sequential` model on top '
                       'of a tensor that comes from a Keras layer '
                       'other than an `InputLayer`. '
                       'Use the functional API instead.')
  return Sequential(layers=[bottom_layer] + cloned_layers, name=model.name)


def clone_model(model, input_tensors=None):
  """Clone any `Model` instance.

  Dispatches to the `Sequential`-specific cloning routine when `model` is a
  `Sequential`, and to the functional-model routine otherwise. In both cases
  the clone is built from newly created layers (and thus new weights) rather
  than sharing the layers of the original model.

  Arguments:
      model: Instance of `Model`
          (could be a functional model or a Sequential model).
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new inputs tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value.
  """
  # Pick the cloning routine first, then make a single call.
  clone_fn = (_clone_sequential_model
              if isinstance(model, Sequential)
              else _clone_functional_model)
  return clone_fn(model, input_tensors=input_tensors)
